#!/usr/bin/env python
# MIT License
#
# Deep Joint Demosaicking and Denoising
# Siggraph Asia 2016
# Michael Gharbi, Gaurav Chaurasia, Sylvain Paris, Fredo Durand
# 
# Copyright (c) 2016 Michael Gharbi
# 
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Utility to convert a folder containing images to a lmdb database usable by Caffe."""

import argparse
import os
import re
import shutil

import caffe
import lmdb
import numpy as np
from PIL import Image

# Value passed as ``map_size`` when opening the LMDB environment: 2**40 bytes.
DB_MAP_SIZE = 1 << 40
# Number of random bytes drawn for the key of each stored record.
DB_KEY_BYTES = 4
# Width and height, in pixels, that every input image must have.
PATCH_SIZE = 128

def main(args):
    """Convert a folder of PNG patches into an LMDB database of Caffe datums.

    ``args.input`` is walked recursively. Every file whose name ends in
    ``.png`` is opened, checked to be PATCH_SIZE x PATCH_SIZE, converted to a
    channel-first array (a fourth channel, if present, is dropped) and stored
    as a serialized Caffe datum under a random DB_KEY_BYTES-byte key. One
    write transaction is used per visited directory. Files that cannot be
    read are skipped and listed at the end of the run.

    Args:
        args: object exposing ``input`` (folder to scan) and ``output``
            (LMDB directory) attributes, e.g. an argparse namespace.

    Raises:
        ValueError: if an image is not PATCH_SIZE x PATCH_SIZE, or if it has
            no channel axis / a single channel.
    """
    # Anchored at the end so that names such as "x.png.bak" are not picked up.
    png_pattern = re.compile(r'.*\.png$')
    env = lmdb.open(args.output, map_size=DB_MAP_SIZE)
    n = 1
    invalid = []
    try:
        for d, dirs, files in os.walk(args.input):
            # The context manager commits on success and aborts the
            # transaction if an exception escapes.
            with env.begin(write=True) as txn:
                for f in files:
                    if not png_pattern.match(f):
                        continue

                    path = os.path.join(d, f)
                    print('image %d %s' % (n, f))
                    try:
                        # Image.open only identifies the file; pixel data is
                        # decoded on conversion, so both steps are guarded.
                        im = Image.open(path)
                        w, h = im.size
                        if w != PATCH_SIZE or h != PATCH_SIZE:
                            raise ValueError(
                                "Patch size should be {}, got {}x{}".format(
                                    PATCH_SIZE, w, h))
                        im = np.array(im)
                    except IOError:
                        print('  could not read %s' % path)
                        invalid.append(path)
                        continue

                    # Single-band images decode to a 2-D array with no
                    # channel axis; reject them like 1-channel arrays.
                    if im.ndim < 3 or im.shape[2] == 1:
                        raise ValueError(
                            "Does not accept monochromatic images.")

                    if im.shape[2] == 4:
                        # Keep the first three channels only.
                        im = im[:, :, :3]

                    # (height, width, channel) -> (channel, height, width).
                    im = im.transpose((2, 0, 1))
                    payload = caffe.io.array_to_datum(im).SerializeToString()

                    # Random keys can collide; with overwrite=False put()
                    # returns False instead of silently replacing an existing
                    # record, so draw a new key until the write succeeds.
                    key = np.random.bytes(DB_KEY_BYTES)
                    while not txn.put(key, payload, overwrite=False):
                        key = np.random.bytes(DB_KEY_BYTES)

                    n += 1
    finally:
        # Release the environment even when a ValueError aborts the run.
        env.close()
    print(invalid)
    print('%d invalid' % len(invalid))
    

if __name__ == "__main__":
    # Script entry point: read the two folder options from the command line
    # and run the conversion with them.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input',
        type=str,
        default="data/images/train",
        help='path to the input folder containing images.')
    cli.add_argument(
        '--output',
        type=str,
        default="data/db_train",
        help='target directory for the lmdb database.')
    main(cli.parse_args())
